-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[ROCDL] Added rocdl.cvt.scale.pk8
ops
#161411
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir Author: None (ravil-mobile) ChangesThis patch introduces some missing FP conversion instructions in the ROCDL dialect Specifically:
Tests:
Full diff: https://github.com/llvm/llvm-project/pull/161411.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 8b687a7f29bef..ff78c7f80fe07 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -985,7 +985,6 @@ class ScaleArgInfo<TypeConstraint argTyVal, string typeName> {
//===---------------------------------------------------------------------===//
// Scaled {fp4,bf8,fp8} to {bf16,f16,f32} conversion intrinsics
//===---------------------------------------------------------------------===//
-
foreach smallT = [
ScaleArgInfo<I32, "Fp4">,
ScaleArgInfo<ROCDL_V2I32Type, "Fp8">,
@@ -996,6 +995,8 @@ foreach smallT = [
ScaleArgInfo<ROCDL_V8BF16Type, "Bf16">,
ScaleArgInfo<ROCDL_V8F32Type, "F32">,
] in {
+
+ // Up-scaling
def ROCDL_CvtPkScalePk8 # largeT.nameForOp # smallT.nameForOp # Op :
ROCDL_ConcreteNonMemIntrOp<"cvt.scale.pk8." # largeT.name # "." # smallT.name,
[Pure], 1, [2], ["scaleSel"]>,
@@ -1010,13 +1011,30 @@ foreach smallT = [
attr-dict $src `,` $scale `[` $scaleSel `]` `:` type($res)
}];
}
+
+ // Down-scaling
+ def ROCDL_CvtScaleF32Pk8 # smallT.nameForOp # largeT.nameForOp # Op :
+ ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk8." # smallT.name # "." # largeT.name,
+ [Pure], 1>,
+ Arguments<(ins largeT.type:$src, F32:$scale)> {
+ let results = (outs smallT.type:$res);
+ let summary = "Scale and convert packed "
+ # largeT.name # " to packed " # smallT.name ;
+ let description = [{
+ Convert 8 packed }] # smallT.name # [{ values to packed }]
+ # largeT.name # [{, multiplying by the exponent part of `scale`
+ before doing so.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $scale `:` type($res)
+ }];
+ }
} // foreach largeT
} // foreach smallTOp
//===---------------------------------------------------------------------===//
// Scaled {bf6,fp6} to {bf16,f16,f32} conversion intrinsics
//===---------------------------------------------------------------------===//
-
foreach smallT = [
ScaleArgInfo<ROCDL_V3I32Type, "Fp6">,
ScaleArgInfo<ROCDL_V3I32Type, "Bf6">
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index e043a8c533d05..00ee6b795c43a 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1340,6 +1340,34 @@ llvm.func @rocdl.cvt.scale.pk8(%i32: i32, %v2xi32: vector<2xi32>, %scale: i32) {
llvm.return
}
+// CHECK-LABEL: rocdl.cvt.scalef32.pk8
+// CHECK-SAME:(<8 x float> %[[V8F32:.+]], <8 x half> %[[V8F16:.+]], <8 x bfloat> %[[V8BF16:.+]], float %[[SCALE:.+]])
+llvm.func @rocdl.cvt.scalef32.pk8(%v8xf32: vector<8xf32>, %v8xf16: vector<8xf16>, %v8xbf16: vector<8xbf16>, %scale: f32) {
+
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.fp8.f32(<8 x float> %[[V8F32]], float %[[SCALE]])
+ %0 = rocdl.cvt.scalef32.pk8.fp8.f32 %v8xf32, %scale : vector<2xi32>
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.bf8.f32(<8 x float> %[[V8F32]], float %[[SCALE]])
+ %1 = rocdl.cvt.scalef32.pk8.bf8.f32 %v8xf32, %scale : vector<2xi32>
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk8.fp4.f32(<8 x float> %[[V8F32]], float %[[SCALE]])
+ %2 = rocdl.cvt.scalef32.pk8.fp4.f32 %v8xf32, %scale : i32
+
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.fp8.f16(<8 x half> %[[V8F16]], float %[[SCALE]])
+ %3 = rocdl.cvt.scalef32.pk8.fp8.f16 %v8xf16, %scale : vector<2xi32>
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.bf8.f16(<8 x half> %[[V8F16]], float %[[SCALE]])
+ %4 = rocdl.cvt.scalef32.pk8.bf8.f16 %v8xf16, %scale : vector<2xi32>
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk8.fp4.f16(<8 x half> %[[V8F16]], float %[[SCALE]])
+ %5 = rocdl.cvt.scalef32.pk8.fp4.f16 %v8xf16, %scale : i32
+
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.fp8.bf16(<8 x bfloat> %[[V8BF16]], float %[[SCALE]])
+ %6 = rocdl.cvt.scalef32.pk8.fp8.bf16 %v8xbf16, %scale : vector<2xi32>
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.bf8.bf16(<8 x bfloat> %[[V8BF16]], float %[[SCALE]])
+ %7 = rocdl.cvt.scalef32.pk8.bf8.bf16 %v8xbf16, %scale : vector<2xi32>
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk8.fp4.bf16(<8 x bfloat> %[[V8BF16]], float %[[SCALE]])
+ %8 = rocdl.cvt.scalef32.pk8.fp4.bf16 %v8xbf16, %scale : i32
+
+ llvm.return
+}
+
// CHECK-LABEL: @rocdl.cvt.scale.pk16
// CHECK-SAME:(<3 x i32> %[[SRC0:.+]], i32 %[[SCALE:.+]])
llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) {
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add print/parse tests in mlir/test/Dialect/LLVMIR, otherwise LGTM
276fb5a
to
f64803a
Compare
Also, before merging, change the title to reflect the operation's name you added accurately. I.e., instead of |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Just the small corrections.
6338f91
to
cf200fb
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This patch introduces some missing FP conversion instructions in the ROCDL dialect Specifically: - Downscaling 8x packed F16, Bf16, Fp32 values to Fp8, Bf8, Fp4 Tests: - Added lit-tests to check MLIR -> LLVM lowering
This patch introduces some missing FP conversion instructions in the ROCDL dialect
Specifically:
Tests: